package cn.jihhhh.chatgpt.test;

import cn.jihhhh.chatgpt.common.Constants;
import cn.jihhhh.chatgpt.domain.chat.ChatCompletionRequest;
import cn.jihhhh.chatgpt.domain.chat.Message;
import cn.jihhhh.chatgpt.session.Configuration;
import cn.jihhhh.chatgpt.session.OpenAiSession;
import cn.jihhhh.chatgpt.session.OpenAiSessionFactory;
import cn.jihhhh.chatgpt.session.defaults.DefaultOpenAiSessionFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.junit.Before;
import org.junit.Test;

import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 单元测试
 *
 * @author JIHHHH
 * @date 2025-01-13
 */
@Slf4j
public class ApiTest {
    private OpenAiSession openAiSession;

    @Before
    public void test_OpenAiSessionFactory() {
        // 1 官网原始 apiHost https://api.openai.com/ - 官网的Key可直接使用
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://pro-share-aws-api.zcyai.com/");
        configuration.setApiKey("sk-b0A0eSKTNxgBqrHv7aAa0808EdB849C89499D928648bD416");
        // 2. 会话工厂
        OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
        // 3. 开启会话
        this.openAiSession = factory.openSession();
    }

    /**
     * 常用对话模式
     */
    @Test
    public void test_chat_completions_stream_channel() throws JsonProcessingException, InterruptedException {
        // 1. 创建参数
        ChatCompletionRequest chatCompletion = ChatCompletionRequest
                .builder()
                .stream(true)
                .messages(Collections.singletonList(Message.builder().role(Constants.Role.USER).content("1+1").build()))
                .model(ChatCompletionRequest.Model.GPT_3_5_TURBO.getCode())
                .maxTokens(1024)
                .build();

        // 2. 可以在此处直接配置自己的模型, 也可以为空，则使用全局的
        String apiHost = "https://ip/";
        String apiKey = "sk-xxxx";

        // 3. 发起请求
        EventSource eventSource = openAiSession.chatCompletions(apiHost, apiKey, chatCompletion, new EventSourceListener() {
            @Override
            public void onEvent(EventSource eventSource, String id, String type, String data) {
                log.info("测试结果 id:{} type:{} data:{}", id, type, data);
            }

            @Override
            public void onFailure(EventSource eventSource, Throwable t, Response response) {
                log.error("失败 code:{} message:{}", response.code(), response.message());
            }
        });
        // 等待
        new CountDownLatch(1).await();
    }
}
